import os
import random
import torch
import soundfile as sf
from ASVBaselineTool.feature_tools import *
def getlabelmeta(labelpath):
    """Build a mapping from file name to label from a label/protocol file.

    Each non-blank line is split on whitespace. The first token is taken as
    the file name and the last token as the label. The label ``"genuine"`` is
    mapped to ``1`` and ``"fake"`` to ``0``; any other label string is stored
    unchanged.

    Args:
        labelpath: Path to the label file.

    Returns:
        dict mapping file name (str) to label (int 1/0, or the raw label
        string when it is neither "genuine" nor "fake").
    """
    label_codes = {"genuine": 1, "fake": 0}
    label_meta = {}
    with open(labelpath, "r", encoding="utf-8") as fh:
        for line in fh:
            tokens = line.split()
            # Skip blank/whitespace-only lines (e.g. a trailing newline at
            # EOF) so they do not turn into bogus "" / "\n" entries.
            if not tokens:
                continue
            fname, label = tokens[0], tokens[-1]
            label_meta[fname] = label_codes.get(label, label)
    return label_meta

def feature_select(feature, wav: torch.Tensor):
    """Compute the front-end representation named by ``feature`` for ``wav``.

    Args:
        feature: One of "spec_2048", "spec_1024", "mfcc_40", "mfcc_80",
            "lfcc_70", "lfcc_60" (delegated to the matching ``get_*`` helper,
            presumably provided by the star import of
            ``ASVBaselineTool.feature_tools``) or "raw".
        wav: Waveform tensor. For "raw" it is reshaped to ``(1, -1)``.

    Returns:
        The output of the selected ``get_*`` helper, or the reshaped waveform
        for "raw".

    Raises:
        ValueError: if ``feature`` is not one of the supported names.
    """
    if feature == "spec_2048":
        return get_SPEC_2048(wav)
    elif feature == "spec_1024":
        return get_SPEC_1024(wav)
    elif feature == "mfcc_40":
        return get_MFCC_40(wav)
    elif feature == "mfcc_80":
        return get_MFCC_80(wav)
    elif feature == "lfcc_70":
        return get_LFCC_70(wav)
    elif feature == "lfcc_60":
        return get_LFCC_60(wav)
    elif feature == "raw":
        # Raw waveform input: flatten to a single row (batch of one).
        return wav.reshape(1, -1)
    # Fail loudly here rather than returning None, which would only surface
    # later as an obscure AttributeError (e.g. on ``.cuda()``).
    raise ValueError(
        "unknown feature {!r}; expected one of spec_2048, spec_1024, "
        "mfcc_40, mfcc_80, lfcc_70, lfcc_60, raw".format(feature)
    )



def CalASR(model, advsavepath, cleandatapath, trainlabelpath, feature: str, desc: str):
    """Compute and print the transfer attack success rate of saved adversarial audio.

    For every file name in ``advsavepath``, the clean file with the same name
    in ``cleandatapath`` is classified first. Files whose clean version is
    already misclassified are skipped. Among the remaining ("valid") files,
    the attack counts as a success when the adversarial version is classified
    differently from the label in ``trainlabelpath``.

    Args:
        model: Classifier called as ``model(features)``. Its output is
            argmax-ed over dim 1 and ``.item()`` is taken, so a single-row
            batch is expected. Inputs are moved with ``.cuda()``, so the model
            is assumed to be on the GPU.
        advsavepath: Directory containing the adversarial audio files.
        cleandatapath: Directory containing the clean audio files (same file
            names as in ``advsavepath``).
        trainlabelpath: Label file parsed by ``getlabelmeta``.
        feature: Feature name passed to ``feature_select``.
        desc: Description printed before the result.

    Returns:
        The success rate ``successes / valid`` as a float, or ``nan`` when no
        clean file was classified correctly (the rate is undefined).

    Raises:
        KeyError: if an adversarial file name has no entry in the label file.
    """
    labelmeta = getlabelmeta(trainlabelpath)
    # Evaluation only: switch to eval mode once (not per file).
    model.eval()

    def predict(path):
        """Return the predicted class index for the audio file at ``path``."""
        wav, _ = sf.read(path)
        fea = feature_select(feature, torch.Tensor(wav)).cuda()
        return torch.max(model(fea), dim=1)[1].item()

    allvalidf = 0  # clean files classified correctly
    alltranf = 0  # of those, adversarial versions that were misclassified
    # Predictions only: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for fn in os.listdir(advsavepath):
            label = labelmeta[fn]
            # Skip files the model already misclassifies on clean audio.
            if predict(os.path.join(cleandatapath, fn)) != label:
                print("原始分类错误")
                continue
            allvalidf += 1
            if predict(os.path.join(advsavepath, fn)) != label:
                alltranf += 1

    print(desc)
    if allvalidf == 0:
        # Avoid ZeroDivisionError: no correctly classified clean sample.
        print("攻击迁移成功率: N/A (no correctly classified clean samples)")
        return float("nan")
    rate = alltranf / allvalidf
    print("攻击迁移成功率:", str(rate))
    return rate


def CalClip40000TransAtKSuccess(model, source_length, target_length, advsavepath, cleandatapath, trainlabelpath, feature: str, desc: str):
    """Compute and print the transfer success rate of clipped adversarial audio.

    Same procedure as the unclipped success-rate computation, except that each
    adversarial waveform is truncated to its first ``target_length`` samples
    before feature extraction. Clean waveforms are not truncated. Files whose
    clean version is already misclassified are skipped. Among the remaining
    files, the attack counts as a success when the clipped adversarial version
    is classified differently from the label.

    Args:
        model: Classifier called as ``model(features)``. Its output is
            argmax-ed over dim 1 and ``.item()`` is taken, so a single-row
            batch is expected. Inputs are moved with ``.cuda()``.
        source_length: Original sample length; must be 64600.
        target_length: Clip length in samples; must be 40000.
        advsavepath: Directory containing the adversarial audio files.
        cleandatapath: Directory containing the clean audio files (same file
            names as in ``advsavepath``).
        trainlabelpath: Label file parsed by ``getlabelmeta``.
        feature: Feature name passed to ``feature_select``.
        desc: Description printed before the result.

    Returns:
        The success rate as a float, or ``nan`` when no clean file was
        classified correctly.

    Raises:
        AssertionError: if the lengths differ from 64600 / 40000.
        KeyError: if an adversarial file name has no entry in the label file.
    """
    print("测试剪切的迁移性：原始样本长度{},剪切后样本长度{}".format(source_length,target_length))
    # This experiment is defined only for the 64600 -> 40000 configuration.
    assert source_length == 64600, "source_length must be 64600"
    assert target_length == 40000, "target_length must be 40000"
    labelmeta = getlabelmeta(trainlabelpath)
    # Evaluation only: switch to eval mode once (not per file).
    model.eval()

    def predict(path, clip=None):
        """Return the predicted class for ``path``, optionally truncated to ``clip`` samples."""
        wav, _ = sf.read(path)
        if clip is not None:
            wav = wav[:clip]
        fea = feature_select(feature, torch.Tensor(wav)).cuda()
        return torch.max(model(fea), dim=1)[1].item()

    allvalidf = 0  # clean files classified correctly
    alltranf = 0  # of those, clipped adversarial versions misclassified
    # Predictions only: no autograd graph is needed for any forward pass.
    with torch.no_grad():
        for fn in os.listdir(advsavepath):
            label = labelmeta[fn]
            # Skip files the model already misclassifies on clean audio.
            if predict(os.path.join(cleandatapath, fn)) != label:
                print("原始分类错误")
                continue
            allvalidf += 1
            # Clip using the parameter, not a hard-coded literal.
            if predict(os.path.join(advsavepath, fn), clip=target_length) != label:
                alltranf += 1

    print(desc)
    if allvalidf == 0:
        # Avoid ZeroDivisionError: no correctly classified clean sample.
        print("攻击迁移成功率: N/A (no correctly classified clean samples)")
        return float("nan")
    rate = alltranf / allvalidf
    print("攻击迁移成功率:", str(rate))
    return rate


